﻿#pragma kernel Project
#pragma kernel Apply

#include "ContactHandling.cginc"
#include "Integration.cginc"
#include "CollisionMaterial.cginc"
#include "Simplex.cginc"
#include "AtomicDeltas.cginc"

// ---- Read-only inputs -------------------------------------------------------

// Indexed by dispatch thread id in Apply to obtain the particle that thread processes.
StructuredBuffer<int> particleIndices;
// Flat list of particle indices; Project reads simplices[start + j] for each particle j of a simplex.
StructuredBuffer<int> simplices;
// Per-particle position/orientation that Project differentiates the current state against
// (over substepTime) to obtain linear/angular velocities.
StructuredBuffer<float4> prevPositions;
StructuredBuffer<quaternion> prevOrientations;
// NOTE(review): not referenced directly by the kernels in this file; possibly consumed by an include - confirm.
StructuredBuffer<float> invMasses;
// Per-particle value blended per simplex in Project and fed to GetParticleInertiaTensor for rolling contacts.
StructuredBuffer<float> invRotationalMasses;
// Per-particle radii; only .xyz is read here (passed to EllipsoidRadius in Project).
StructuredBuffer<float4> principalRadii;

// ---- Read/write particle state ----------------------------------------------

// Current positions/orientations: read by Project, handed to Apply*Delta in Apply.
RWStructuredBuffer<float4> positions;
RWStructuredBuffer<quaternion> orientations;
// NOTE(review): not referenced directly in this file; presumably the accumulation target used by
// the AtomicAdd*/Apply* helpers from AtomicDeltas.cginc - confirm.
RWStructuredBuffer<float4> deltas;

// ---- Contact data -----------------------------------------------------------

// One entry per contact; Project indexes both buffers with the same contact index.
RWStructuredBuffer<contact> particleContacts;
RWStructuredBuffer<contactMasses> effectiveMasses;
// Element [3] is used by Project as the exclusive upper bound on the contact index.
StructuredBuffer<uint> dispatchBuffer;

// Variables set from the CPU
uint particleCount; // exclusive upper bound on the thread index in Apply
float substepTime;  // time span used to differentiate velocities and to scale impulses into deltas in Project
float stepTime;     // forwarded to SolveFriction
float sorFactor;    // forwarded to ApplyPositionDelta / ApplyOrientationDelta

// Friction projection kernel: one thread per particle-particle contact.
//
// For contact i this kernel:
//   1. Blends per-particle velocities, orientations, rotational masses and radii over the
//      particles of each simplex, weighted by the contact's pointA / pointB components.
//   2. Optionally (material.rollingContacts > 0) adds the angular contribution at the
//      contact offset to the linear velocities.
//   3. Asks SolveFriction for the tangent / bitangent impulse magnitudes.
//   4. Accumulates the resulting position deltas (and, for rolling contacts, orientation
//      deltas) on every particle of both simplices through the AtomicAdd* helpers.
//
// Nothing is written to positions/orientations here; accumulated deltas are consumed by Apply.
//
// NOTE(review): particleContacts[i] is deliberately passed straight to SolveFriction /
// SolveRollingFriction rather than through a local copy. Their signatures live in the
// includes and may take the contact as inout - confirm before caching it in a local.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID) 
{
    unsigned int i = id.x;

    // dispatchBuffer[3] is the exclusive bound on valid contact indices for this dispatch.
    if (i >= dispatchBuffer[3]) return;

    int simplexSizeA;
    int simplexSizeB;
    int simplexStartA = GetSimplexStartAndSize(particleContacts[i].bodyA, simplexSizeA);
    int simplexStartB = GetSimplexStartAndSize(particleContacts[i].bodyB, simplexSizeB);

    // Combine collision materials (use material from first particle in simplex)
    collisionMaterial material = CombineCollisionMaterials(collisionMaterialIndices[simplices[simplexStartA]],
                                                           collisionMaterialIndices[simplices[simplexStartB]]);

    // Weighted per-simplex accumulators for body A.
    float4 linearVelocityA = float4(0,0,0,0);
    float4 angularVelocityA = float4(0,0,0,0);
    float invRotationalMassA = 0;
    quaternion orientationA = quaternion(0, 0, 0, 0);
    float simplexRadiusA = 0;

    // Weighted per-simplex accumulators for body B.
    float4 linearVelocityB = float4(0,0,0,0);
    float4 angularVelocityB = float4(0,0,0,0);
    float invRotationalMassB = 0;
    quaternion orientationB = quaternion(0, 0, 0, 0);
    float simplexRadiusB = 0;

    // Gather simplex A: every quantity is weighted by the contact's pointA[j].
    int j = 0;
    for (j = 0; j < simplexSizeA; ++j)
    {
        int particleIndex = simplices[simplexStartA + j];
        linearVelocityA  += DifferentiateLinear(positions[particleIndex], prevPositions[particleIndex], substepTime) * particleContacts[i].pointA[j];
        angularVelocityA += DifferentiateAngular(orientations[particleIndex], prevOrientations[particleIndex], substepTime) * particleContacts[i].pointA[j];
        invRotationalMassA += invRotationalMasses[particleIndex] * particleContacts[i].pointA[j];
        orientationA += orientations[particleIndex] * particleContacts[i].pointA[j];
        simplexRadiusA += EllipsoidRadius(particleContacts[i].normal, prevOrientations[particleIndex], principalRadii[particleIndex].xyz) * particleContacts[i].pointA[j];
    }

    // Gather simplex B: same as above, weighted by pointB[j].
    for (j = 0; j < simplexSizeB; ++j)
    {
        int particleIndex = simplices[simplexStartB + j];
        linearVelocityB  += DifferentiateLinear(positions[particleIndex], prevPositions[particleIndex], substepTime) * particleContacts[i].pointB[j];
        angularVelocityB += DifferentiateAngular(orientations[particleIndex], prevOrientations[particleIndex], substepTime) * particleContacts[i].pointB[j];
        invRotationalMassB += invRotationalMasses[particleIndex] * particleContacts[i].pointB[j];
        orientationB += orientations[particleIndex] * particleContacts[i].pointB[j];
        simplexRadiusB += EllipsoidRadius(particleContacts[i].normal, prevOrientations[particleIndex], principalRadii[particleIndex].xyz) * particleContacts[i].pointB[j];
    }

    // Offsets along the contact normal, scaled by each simplex radius.
    // They stay zero unless rolling contacts are enabled.
    float4 rA = FLOAT4_ZERO;
    float4 rB = FLOAT4_ZERO;

    // Consider angular velocities if rolling contacts are enabled:
    if (material.rollingContacts > 0)
    {
        rA = -particleContacts[i].normal * simplexRadiusA;
        rB = particleContacts[i].normal * simplexRadiusB;

        // Add the velocity induced by rotation at the offset (w x r).
        linearVelocityA += float4(cross(angularVelocityA.xyz, rA.xyz), 0);
        linearVelocityB += float4(cross(angularVelocityB.xyz, rB.xyz), 0);
    }

    // Calculate relative velocity:
    float4 relativeVelocity = linearVelocityA - linearVelocityB;

    // Determine impulse magnitude:
    // impulses.x is applied along the contact tangent, impulses.y along the bitangent (see below).
    float tangentMass = effectiveMasses[i].tangentInvMassA + effectiveMasses[i].tangentInvMassB;
    float bitangentMass = effectiveMasses[i].bitangentInvMassA + effectiveMasses[i].bitangentInvMassB;
    float2 impulses = SolveFriction(particleContacts[i],tangentMass,bitangentMass,relativeVelocity, material.staticFriction, material.dynamicFriction, stepTime);

    // Skip all delta accumulation when both impulse components are negligible.
    if (abs(impulses.x) > EPSILON || abs(impulses.y) > EPSILON)
    {
        float4 tangentImpulse   = impulses.x * particleContacts[i].tangent;
        float4 bitangentImpulse = impulses.y * GetBitangent(particleContacts[i]);
        float4 totalImpulse = tangentImpulse + bitangentImpulse;

        // Distribute the positional correction over simplex A's particles (positive sign).
        float baryScale = BaryScale(particleContacts[i].pointA);
        for (j = 0; j < simplexSizeA; ++j)
        {
            int particleIndex = simplices[simplexStartA + j];
            float4 delta1 = (tangentImpulse * effectiveMasses[i].tangentInvMassA + bitangentImpulse * effectiveMasses[i].bitangentInvMassA) * substepTime * particleContacts[i].pointA[j] * baryScale; 
            AtomicAddPositionDelta(particleIndex, delta1);
        }

        // Simplex B receives the opposite correction (negative sign).
        baryScale = BaryScale(particleContacts[i].pointB);
        for (j = 0; j < simplexSizeB; ++j)
        {
            int particleIndex = simplices[simplexStartB + j];
            float4 delta2 = -(tangentImpulse * effectiveMasses[i].tangentInvMassB + bitangentImpulse * effectiveMasses[i].bitangentInvMassB) * substepTime * particleContacts[i].pointB[j] * baryScale; 
            AtomicAddPositionDelta(particleIndex, delta2);
        }

        // Rolling contacts:
        if (material.rollingContacts > 0)
        {
            // FLOAT4_EPSILON is added to the denominator before taking the reciprocal.
            float4 invInertiaTensorA = 1.0/(GetParticleInertiaTensor(simplexRadiusA, invRotationalMassA) + FLOAT4_EPSILON);
            float4 invInertiaTensorB = 1.0/(GetParticleInertiaTensor(simplexRadiusB, invRotationalMassB) + FLOAT4_EPSILON);

            // Calculate angular velocity deltas due to friction impulse:
            float4x4 solverInertiaA = TransformInertiaTensor(invInertiaTensorA, orientationA);
            float4x4 solverInertiaB = TransformInertiaTensor(invInertiaTensorB, orientationB);

            float4 angVelDeltaA = mul(solverInertiaA, float4(cross(rA.xyz, totalImpulse.xyz), 0));
            float4 angVelDeltaB = -mul(solverInertiaB, float4(cross(rB.xyz, totalImpulse.xyz), 0));

            // Final angular velocities, after adding the deltas:
            angularVelocityA += angVelDeltaA;
            angularVelocityB += angVelDeltaB;

            // Calculate weights (inverse masses):
            float invMassA = length(mul(solverInertiaA, normalizesafe(angularVelocityA)));
            float invMassB = length(mul(solverInertiaB, normalizesafe(angularVelocityB)));

            // Calculate rolling axis and angular velocity deltas:
            float4 rollAxis = FLOAT4_ZERO;
            float rollingImpulse = SolveRollingFriction(particleContacts[i],angularVelocityA, angularVelocityB, material.rollingFriction, invMassA, invMassB, rollAxis);
            angVelDeltaA += rollAxis * rollingImpulse * invMassA;
            angVelDeltaB -= rollAxis * rollingImpulse * invMassB;

            // Apply orientation deltas to particles:
            quaternion orientationDeltaA = AngularVelocityToSpinQuaternion(orientationA, angVelDeltaA, substepTime);
            quaternion orientationDeltaB = AngularVelocityToSpinQuaternion(orientationB, angVelDeltaB, substepTime);

            // Every particle of a simplex receives the same (unweighted) orientation delta.
            for (j = 0; j < simplexSizeA; ++j)
            {
                int particleIndex = simplices[simplexStartA + j];
                AtomicAddOrientationDelta(particleIndex, orientationDeltaA);
            }

            for (j = 0; j < simplexSizeB; ++j)
            {
                int particleIndex = simplices[simplexStartB + j];
                AtomicAddOrientationDelta(particleIndex, orientationDeltaB);
            }
        }
    }
}

// Delta application kernel: one thread per entry of particleIndices.
// Hands the particle's position and orientation buffers to the Apply*Delta helpers
// together with sorFactor.
// NOTE(review): presumably this folds the deltas accumulated by Project into the particle
// state, with sorFactor as a relaxation factor - confirm in AtomicDeltas.cginc.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID) 
{
    // Threads past particleCount (padding in the last group) have nothing to do.
    if (id.x < particleCount)
    {
        // Map the dispatch slot to the actual particle it refers to.
        int particle = particleIndices[id.x];

        ApplyPositionDelta(positions, particle, sorFactor);
        ApplyOrientationDelta(orientations, particle, sorFactor);
    }
}


